{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "49d410bc",
   "metadata": {},
   "source": [
    "# 1. Implementing the GCN Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "d79f3531",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree"
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEjCAYAAAD0aHdQAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAADb2SURBVHhe7d0HsF1l1T7wbe+KBVDBAqiAKBCaQOjSQgghQMCI0pTqgIoziow4nw6jA4iI4AwSDL2FIJ1A6E0EgihIURALINh7L9+f3/vnJSf3u2Wfcm/2Oe96Zvbcds4++56zn3et9azyPud/n0YVCAQGHs995msgEBhwBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQZA8ECkGQPRAoBEH2QKAQBNkDgUIQWzYHBhL//e9/q7/85S/VL3/5y+rnP/959Zvf/Kb685//XP3jH/9If3v+859fveIVr6he/epXV8suu2z1xje+sVpqqaWq5z53cO1fkD0wEEDgf/3rX9Uf//jH6pFHHqkefvjh6ic/+Un11FNPVb/73e8S0f/+97+nx7jln/e851UvfvGLq5e//OXVq171quq1r31t9aY3valaZZVVqre//e3VMsssU73whS+snvOc56RjEBBkD/Q1/vOf/yQi//rXv64ef/zx6tFHH63uvffe6oEHHkhWHVFZ8Je85CWJvEjud56H+H/729+SB+Ar4q+22mrV6quvXq288sqJ/Ky+57/gBS/oe9IH2QN9CbctwnLPH3zwwerb3/52dfvttyeLzkq/+c1vrt72trelw/dLL710suCsOVc9ewG/+tWvqieeeKL68Y9/nJ77ox/9KC0SHrvBBhtUm2yySfXud787kd5z+5nwQfZAXwJZWfJ58+ZV1157bXLX3/rWt1abb7559a53vat6y1veUr3mNa9JBBWfI7ijlaxcf4dF45///Gf1pz/9KZ2TV3DnnXdWd911V3rOpptuWk2fPr1ab731qhe96EXPPLv/EGQP9B3++te/Vg899FD11a9+tbr//vuTu80Cs8RveMMbqle+8pXVS1/60rZc7+wpiOuR/he/+EXyGC677LJk8Zdffvlq6tSp1U477dS3Ql6QPdA3QEYkvPvuu6v58+dXCxcuTILae9/73mrjjTeuVlhhhZ652V7rD3/4QwoPrrzyymTtkRzht9xyy+r1r39931n5IHugL+A2RfRbb721uuKKK5IIt+GGG1YzZsyo1lhjjeSyj0c8jfTf+973qquvvjoR/9///ne14447VltssUXyKIh+/YIge6DxcIsiGcKde+651WOPPZYs+R577JHEN0r7eEMsf80111TnnHNOsvgHHnhgsvDLLbdc37j0UUEXaDwQ/cknn6zOPvvsVCCzzTbbVIccckhy2wlwEwFu+1ZbbVXts88+KU03e/bstPjQD/oFQfZAo0F1/+lPf1odf/zxifDvec97qq233rp63ete92zRy0SAoq/QZvLkydVee+2VfnfzzTensIKS3w8OcpA90FhIi0mpIRUrypJz39/xjnek4piJhsWF2j9t2rSkF/z2t7+tbrrpppSnd61NR5A90Fhwl3/wgx9U1113XSITohPjpNbahdJZcXe3brd0nnJaqjziS8/Jx/NAmo4ge6CxoL7fc889iVDEMNYUwUZy3S0IGl3kyi0UrQcFXwrthz/8Yfp7N26311977bWrddZZJ+kJzi0333Tr/rz/eRrPfB8INAZIJJ/OTVb7fthhh1UrrrjiiLntnBeXkrM45NLXfIj3v/Wtb6XFQzktFb2bUEAM7zV5C7fddlu11lprJR1hogTDThCWPdBIsMYI+/vf/75af/31U7w+WorN4uCxeYG4/vrrFzvuu+++1A2XD0TtxrpbKKT9N
M5w7ZXtyhQ0WagLsgcaCV1sP/vZz5LLzF3WkTZaPhv5dKepqONir7vuuosdGluQkwXOAl+3Sr5Cntxow2uQNRAiNBVRVBNoJFjnM844I1lKOXVtp2MVr4iZCWXDxc6XX355amVFdIQXDnRLdtdGQDz99NOrOXPmVCJi1XV0hSYiLHugkUAi5JTbdtQhpsVAeszBtfaV6+9Q3oqIvIRe9aY7h/h/0qRJ6Wdtslpmm4oge6BxYDHF65R1ROee14Hnid3l5sXl8t9icxAGEPD8Xr96r5RzaUCpOIIdEVDuvanOcpA90CggCpIjLIiL61TKeY5cOnf6C1/4QnXkkUdWp512WvX9738/ufYLFixIVXjHHXdcSsH1Ki/OS8hz7GQNmpyCC7IHGgVE4b4jDSKx6nVSZFx4PeyGVhgeKUd/ww03pIIcqTGHqTbaVFnjXrjx4HWl28TpvAjXzrtoIoLsgUYhk129ObLXHQXlsazrtttuW+2+++4pNrdg6FS76KKL0mM222yzaubMmWnIBbe7V7AYybEDD6Op1XRB9kCjwI1HdFYSiZC4XRhPpWFFSkwRjX50FXiIbrSUqTNDlX2LjMVB6izH+XVhMXrZy16WvmfVw40PBNoE4ncqdhHNFLyw4GuuuWYiPjd/OHgNBTlifHXuRk+3+7qZ4L0KD8YDQfZAo4AseUgkC9uJS4yoxDLqe67E8/Nw8FiqvVlzl156aRL5uOLtANHz+V13HY1hSSDIHmgUuNfSZAiP6LrU6rrFiOtQfSdPT5DjshtOqbrN75HSApAf6zH+ztVH+nZeL8OiJHtgoZLT7yT0mAgE2QONArKLf6nm4l+z3evE0OJtHW2mzep/V+AiR2/YBRc9z5X3GBYf0cXnHofkUnyscrtuuGuzQMixI7mFqpfiXy8RZA80CsiGNHm2G8tbpyUVaZWtnnDCCanMVm7erHcjpuXBr7rqqjS/7o477kjkdD7pOTvI5Md2MsuOy69Ix3Wqv5fWGyr+NQVB9kAjYfsl9evcY9Z9LLJz0W+55ZbqkksuSc+1WYQGGE0xNo1Q2cY9J9RZSHzPdXdezTOd5t55CVpqnc95eQhNRZA90EhQ0xXU5O2ZxiL7O9/5zuqII45IG0eYOqvhhfqO3J/61KdS5dzBBx+cFgDn4u7rqrOgIKjcPpecF8Fa1wkdnMf1aavljei3t81UUxHDKwKNBCtLZNMjzurab220WBixzXFHcpNgueQ5JFDdph/e351LfG9SrVJalplXYOgFz8BrOBfNwDEaLBC8A+W3zmu3GC2vocYHAm0A0RTHII5915BytDJUcTIrjeStZEN4BCb65bZWLre43WN9NclGzE+lF8eLwRF5LLDqFiShhl1peBHi/6YiyB5oJKTeWGllqHkOHQKO5c7XAcHuox/9aHX44YdXhx56aPWRj3wkuf7Ue5s4Oiw0o8HCIycvHABWval97BlB9kAjwVIjHHGNpf7617+e2lN7UXfufBYRgpo8vANRLTDid4uB70eDdJ5FSJihFNc5eA5NRpA90FgYDKHUlaLOlTe9RnFMu0UvwwHhHVx8h7SZ4RaEPsLgaMo88c61fOc730nPs0ONOH+05zQBQfZAY8FSErxYTsQ3543yLY3WayiG2WGHHVIaTnw/EnEp9d/97neTKMe6m2PvaGrVXCuC7IHGAuG426bLKnpRuGK7Jcp5nUKbdkBY0zgjbh/OHfdaBDxC3gUXXJCU/LwvvOc0tZCmFUH2QKOBRPrU99133xS/E8Q0rah8Q/heuPR1oO1WCKES7+KLL07XZKNHeftW9b/JiDx7oPFg4VlbwpnUGFde84qJNFJ0OaU2XmDVpdh0xV144YVJ2Ntvv/3SdlTc/6bH6hlB9kDjgUwOhSvILTeuqYU6j+h+N5ao1gmQnEW3ycT555+f6uotMNJ0LDodoV+sOgTZA30BRBZX59QYIoqbH3vssVQRR1G3GCBfXhw6Re5ks7WT3WO/+c1vpp546
TX5dBYd0b1mPyE2iQj0HVhbJD/11FNT6yolXDmtjR/VpxP1ED8PwagDNKC0WzgIgermLSYGVvod8W7KlCmpUq4f0mzDIcge6EuwvlJw1157bbK+hDuW36BJuXkpNJaYe28xYPGJfdnqu+2Je87jIPYpexWbU/ul1whyFHd5dCTXTddPbvtQBNkDfQtkZXV1xSGooRVSc37P1UfUlVZaKXXQibX9jlVGet6BWXMILp3GTUduFXEWA6W68vsGVBpPrXimyXXvdRBkD/Q1EBtx9bxT6hHWjq3Ij8Tq6bnzhDwWHtHd8g4W3XPlz32vkUWBjJBAMU9eIDy/ny16RpA9MDBwKyMvwlPq5eKNi2LBeQBENwuCBhbEl7dXE0/NF+ez4Nx/pPe7fozLR0OQPTCQcFtrmjFgUiuq0laHenZFOSz9Pvvsk+J7Ayeo6/1Q8toNooIuMJBglXOqjrBm8KTWVW66Pnau+aqrrppUdsMuBp3oEGQPFAMLQFbkWX3xfkmObZA9UAxayU6QKy2CDbIHigGSOyDIHggUgtJceAiyB4pBq2UvEUH2QJEokfhB9kAx4LZn1z0LdSUhyD4GxHa6ofIWwCqxVGn5faC/gOg5VtcNVxrZo4LuaXgLqLN5i2B11qqtcomlQ/20WeFuEgUZGiocuqo0STj8rDijxBupH6CzzSCKY489Nn1ORx55ZDVp0qSBqHuvg6LJnknOcrPadgJRN23kkWYKtdV+h+g5xsvPQXiDBvVPq7E2glhttdJLVVu5eSJI3xz4HPWnf/nLX04TZD/3uc+lclkufQkoluz+bSTX4rhw4cI0ckinlDpqddM6nhBX15NBCEosWWzW3U3DrddgYZ8wLZa8AM0Uq6++eirNNBE1bzwQaAZ0wOl/t8mjBdlGkMpng+wDDhbblJMFCxakzflYYB1P6qj1QSO6zieuuRrr7Jp7u1h2cTv3Pm8MaISRjf51W7mpnEc/9AYbbJDOV4qr2GRYkH3edno1NPLTn/50qo0Psg8giDNIaiM/K7xNB1hyH/xaa62VrLLdPnMPs7huJFfc2+Z84nweghhfayXvQGeViSfcfKOStthii+TqO18pN1YTYWE2CvrEE09MC/vHP/7x1AxTSqhVzMBJ1tiH/cADD6SRwNdcc01yxU0JnTVrVjVt2rRkjRGdy56JOdKN4Pf+ngU7XoA+aDeP2J3VN8eMBmAhMHI4ewlh5ZcMfCb0mLvuuivtIyfc0tpaCtmLMTNcOO76KaecUs2dOzfF1x/60IfSbp6mhRJsuoWbRnxvlNFhhx1WHXDAASluNyNNnGi/MnPTLDyBiYf3nSfmq8+7tKxJEWTnbttYYM6cOdVtt91W7bLLLtUhhxxSTZ06NfUy9/oDdz4pOe47wvMaEH327NkpdOBhBCYexNVcI8ET472VhIEmexbTxOfnnXdemklmp84ZM2akGFpufLxcaudl5aXkLC72AacPmJJC+Q/rPvFg1fOWUdT4IPuAANF9uEYMi9Gp5AYJGvIvNmd5x1ssc36vY7qpjQknT56cZqLJ9RpVjPAF6aNLHO4HaVNkF6v3+7TYdjGwZEck8bF42bB/ijurTnEnqE1krOamkn7jURCGLEA2B5Tj51oGJgZc+GzZpVbDsg8IsvI6f/78JMZtv/32aXtdxGuH6Mjo6MYCez1uvQIOuXeLzeWXX57idxmBsO4TA2Rn2b3f7omw7AMCW/hIryme2W677dL2uu1Ws7EAcvLSZ91aYISX0rPgbLTRRsnC2Odbes5NGBhfILj33KFCkhpfWgp0IMlu9VbXfsMNN6Q0GJGMGDeSRUdq4pmiGORuPXgHxL0vfvGLqV6eJfb4TuD15fHF7ttuu23K+TunhSms+/jC+8vbE7e7F7jw7Xh4g4CBJDtrqXBCikuNunh9pFXcTcBqS43ZrfOss8569jj77LNTGKDqyr7cF110UdoeyA3TKeR2bS1krDFPQ8Wd2L3TBSRQDzQcC7XPjhLfb
jg3CBg4svtQCWDy6QpaCHI+3NE+WIR3IxD01Lnng8XlJTg0vQgJlMZ2S0w5XuWasgPKai0gIdSNL9wXeRaBeL1Eyz5w5bKaUIhfVHixMWGO6zzSB+v3rL69vdSx260zH4pilL+6QaTRFOKwymLvbm4Uz7VgiB9vuummtPlgPm9gfOC95u0J7/Q/uDdCje9zKJzhxqtDF68TYsYipr9n0ab1cA5uH9IfffTRSU33u26InsGFt8ComXe9jsD4gWU3lMTnqX15vGssmoiB+49t0q9wBYm48EhcBwjcergZssU34EB+vJeun5hRqS6XkhtP8Q+MH3hnyI703vcge59D7I3oPlRxui60Oq6aG4B7p7Lt3HPPTUKd3T793jkffPDBJNbNmzeva4Eug1AnvFDcIRNgAIbXCowPdB7SZdwPJXW6tWKgyE7kIqqJh5dddtmUYhlrBUcwpJZ2u/3226vzzz8/dcaJ+ank1HL17MhOC7AI9ILsrktIwLK7CRG+F+cNDA+fJT2Hx6dnIcjex0BaIowPlNBFia/zgXqMQw27PnQLBIVcS6oKN0004mluvL+7WXpRjOE1nSfvA56HWwbGBz5DWRXvt8+6zr0xaBgoyy4uc4iHrd51waUW3++www6pUUZHnIKa0047LVlgfe/HHHNM9YlPfCIpub0qs8zW3Y3nui1WgfHBww8/nBZTnlSQvc/BsnPfHUjUTlrFB4/AbgKpNtNrEI+1RW6/owGw+nnggRhQLH/LLbek2XMe2y6ch3X3VSjhCPQW2eOTpRHmGRVWaopzoCw70jhyHN4uPBepdahZMBTQUPNbbw7nVUJ7ySWXVCeffHIaiHH88cenn3kD7RbceLzr9XqOQG+B4MRPBVE+R1pOaRNqMgbm7vLhseYOQlcnlhaRldh6rhJbXwl+dACEzOkbU28cRB+PsfEApd602nbibkRndXx13b0KDwKLYMHmgfncZD9KTbvBQP3XClXEwAhE3UbQseAxFgfPURJrTp2bQ5cc4krJSechvHMSerjwqu2IeF/60pdSlZ6/mSqr5LYOsvfhJgTX7Qj0Fj5XnwtxjlUfjzFk/YKBs+ziauTN+fCxCI/QYm6llNJtJtqYZEOUo+jrSrvxxhuT1WbBnc9CoKpOjtxjlF6KBZG3bvjgPDwFngOLLnyoWwAUqIe8mFrAvbe8NTn2UjFQlp17ZuUmtLHSqulGazBBOPGcQZC2AlJQY4FAXqk2Qp2GmLlz56Ycu8YV6TcxPXJmcc2CobyW5eAq1oHr4iHk7aV4DFJ9fnY+15Hj+UBn8L76fN0HCqx8biUvqAMXvPhAkc6KbhyVmG00sKq5bFXf+9prr51mxlkwtt5662qzzTar1llnnZSaUzrLcyD0EHkQ0qLCG7A4KKuVh68Dz1WgYzFhgZTLaqu1W8k3vvGNNDePNyFlZAHIsX2gPugpVHjek88W4UsWQQeu682HqShGmyvrO1rzit8hu8YIBGfRbQckR89qWwBYcg01tnFqrbxCPAuKARRez4LgsBjUgecKHVTtuQktOIhv4VDN5yZ1+B7ZPZ6lYuktNI7AyPD52JJLalQotttuuz17L5SKgSM7q8t1ywqs4RWqpkYihzgfqe3GKu4m8iG0w41hIUDEoSOtuNqssZuJRUd0j4PhFpZWICwCS9exOkZN77rrrmlRESpYVAh+FiwLAoHJgfiEQv9LVu4tbiVbq5FAD+EV0VsQ33usZiIs+wAhf5iIfuuttybyIKyYuldAVhb4nnvuSW47j4BY56bino9FQI9xI37ta19L4cHmm2+ehllYbIQDLJB++m222SYtIhYP1okXYGrOlVdemWr2syDpfwvCLw6Los+HqGqbJ7v+WNRLxkDu9cYiE8BuvvnmJMiw2sjYC0KwGGa+q5nnPVB46QIUfYdFhkLvtYaz8BYERDfuitVWomtTSSFHds9ds/+BbkB/MNiCh8KD8HrOLd5HeHG966AdOLfwwzl8LRmyMd4bi
ySvyWJasgsPA0l2Lq4bnwornka+Xlh35+R2G3nFrc7iWo6xLQReKwtBw5Gde24QZh5xbZY893JomAAI6/dZO/A4C5evQgaWnQUTm6rqc6gJyAVFQhqPKRGsOk/IZ8KFt2iW+l5kDCTZkczhw+XKi69zGWw3Fo/7rsKOVXX+vNEAS2ohQUSqb+5kGwoegMo7LbO8ABZHCCBdN5bX4e8sk5vWGCuWiqhoEXMNinloCBY3hUA0Aa/nbxYM//dw1zSI8H8bI25B9h7JqtBBxnqPBx0Duz+7uNgNf9RRRyVLKoVmP24WciSrOxa8Vaw7az7c2+a8iD/cubOncdJJJ6UhGRYeNfXi9Cy2tYt8PUIWZF+4cGHqx3eT65H3GnbBcbP73oJUgotv0Tv22GOTC7/nnnsmJZ6XU8piNxIG0rID4ol9WU0xMovHvWV5s6VrF24W53V4/tBjpEXE4oB8NoWw8Fhw9t9//9RKO9LiUAf5erw2q++8RlTTAPzf4lYDLSnSFj6PQ/h843f6uk2GxY+AyaPzHgiTeD/+99IxsGR3I/uAudRIL76WExfjiqmRo9MbIBNluGMoeBhianu7qcLj+tsgAildw3DPaRfOgfRcduf0Grk1l6BnsbHgERZpCwp0iH+dLnpNhf+TpmKqEO1i0qRJKaNhgevF+9zvGFiygw/YDW2Fd1OLtcXMYniLwGj5914AqXgUV1xxRZpr56az1xuiE/LGA63/M6Jb2LI4ybugFagszLqD1OGguLiKjixo3muiJqJLvY7nZ9xPGGiyZ7BiPnyWTxznZmcF3OQsoZjZzd6LG14c7dxScDwJeXENNohmCg6yy6WzxOMJ5/d/IT33XvqOxXcdFiAFO7lDz+LgvbAg9ivpvef+HwsrwktVTpkyJQmzg7CQ9QJFkN2HjfDSVQQxN/odd9yR3GsWL5fH5pi7k5sjk5x14UrKgdsjTryMTHvttVe18847V8svv/yEus7+F6/HgltkxPPEOvG89JQFiaeTF0NHfh/6CbwowpxmJp8pUVIj00S+101HEWQHNy/ScWulY0AxCiEn58izpe9ENBObK9OV2507d26aXyenrnqLIszKLOn0j9f2PyK9eFYo4RoVHylAAddoYewH15cYlz8nC6wiJSXIWROxsPfbojWeKIbs0OraIn0Wr6SspKu4f3LUculudovDSMR3o7HiLCQvQZeazR/ld5XSKnmdNm1asjCaa5S1NuHGcw2sd/Z0uPhcXQJe3mSSoDnRHki7yMKnz9Nn5D1HdBkX6UbvuYUtsAgDm2cfCyy5GM9NguxudHlw5LcYELXEuL5n7XJM6+9cRjeV51P5EURqy2KCJG40TS08CO5x1gSahhx28EgsWEqACXf+h7zXnYXA4tC066eJ0EO46tx2166GQUnxvvvum4qOXHdgEYole4Z/H/FZdq4s8c7Nr9ac9eDWI7uvmewIguxiXQSn6nOJ5fC57W5Ai8OSdNnbgfdAYY5iH+lBC58Yf+bMmel/4QX5uUn/j0VJwRTPyWfi2nUgHnrooWmhslA3cYFdkiie7Bn5hhf7qXXnznPR/YzYrLmbyg3PzeUGs3o8gFyzbuQRK96vsOjRMWQPWE2TdLjEYmDeigWvKQTihR100EHJklucTfYVlh1xxBHpaw7B+mXBnQgE2VvgrRCLIzXiOxBArbXvM9ndRAiP2FxFFj/fWP1sTfz/vBminbltZ555ZorliXkyCdz6vJjl28ZX//NE/98W5Pe///0plWhB5n3sscceyaoLr3hWNImw8IsQZB8DyO8t8jXDzYPYg2o1LGysuny89CHhUgMOK68lV4zs/eDxOPzNIjCR74cMim62PNBDqKVBiHclx66ewayAEOkWIXycMeAGZrlZ8HwMunvo/0MagzNYdGIjQYzaLaangtMrlKSK7y0MvJ5OYeHgUfCgshc1lg3yd8+hn2TvC+ldd+4G5MoHFiEse2BUIBFCL1iwINUksOAsvFgZ2e+88
840N18dPos/FiwKWeDkgls06CFeJy8YFhtEddAJ6CMsd6v3QFOZOnVqmsrr79R302i22267JJR6fLjviyPIHhgTSMhdRvhTTz01EZCbjHjIzvoT8aQZR4JzIDYvgAuuRp+oxktQ14DsOVRyXkRnpVX7mepLIBSDq1fwN9VyiG3hmDx5cjV9+vREdo9pcn3AkkSQPVALWbjTomvctfoCtw5iEu4OO+ywlJEYDtxyXoAGFbXriM6yWwAQU1ZD+lJ87ZwWhZwF8VyhE6suFkdqM/s8d9asWUlDsNDY2MMiEEQfGUH2wKhweyC6mB3ZpbwMyJDT9j0iIvkJJ5yQcvIIl58nF672XsGOwiULBLDYCnccvAHufytREZxrn19Tao03ICZHevE4MY5VR3ZWP9z2sRFkD4wKtwdSqTlAXHl4zT0UemRm2ZH16KOPTrE7IiIl622HG/E+q46I0mNqErjm3G1EV6wjRh9KVK9LeEN4Vt45xOdIr2LR3xBcqo0bj/x5oQkMj6Jq4wOdQSzNtZa/ZmXF3ciIoEjnbwpZWHiWXjedPgENNkQ4MTdS5ik6HofouQR5OIvsdxYBC4nHat6xUKjlV8LsNYl0OfXmXE0t7W0KwrIHasFtIk7m0isl1mGmRJX1ZnG50h/84AeTMs7NZ/2l7ij3LC8iImG3RHQdFh/XQD8w6ouHYVfdQw45JGUFFDxl1T6wCEH2QFvIZMsWPY/7MgHI94gvltcjYP6bnDfr3GvyWXiEFxYdHYfGaPMgEF4xTVO6DJuEcOMDbQGBEJe7zHVmySnpyCY+J65x11l07vp4EB3yNXDplcT6WQejlBw9gCcRMfziCF8n0BXkxxFdLT2rLybXiTZeFr0VFh4LDUtOExA2uBaWXveikCOwCEH2QMfgRhPJ5s6dm9R3ltx8fkQfTmEfL/AuVM1ttdVWSfG3Y494nniXC3UCQfZAh0AiKTAltKbzSH0pdkE6RK8L3oD4u1vpSFWfRUb4QKAzL991UepDlvr/CLIHOgKrzmWWXpMaU9mWN72oCyQUBhD2pOi6JSUNQWEPl54L79oMF6UjBILsgQ6hyCYXzKhRN6pagUw7rnu26tJ3tlbWUtuN200fYOGV70rBGRWG8Lk0t3RE6i3QNhDS5pRq3Vnlgw8+OLnQiDYcPF6zi9FfrQ0vbj0pPGRXKae+Pde5dxPzO6dJO9x455Bwoiew/CUjLHugbSBmtuqIqapNnDwSkNtzkFqxDfEsH2J+Qhr1XAed9Fm3brf6eZYdwXXVCTfE7qUjyB5oG1pTjasC2yuNNYySdZXzVvJKyFP8kg+ElCuXG9cY4/teqPim5zg3D8Eg0VziWzLCjQ+0BbfL7Nmzq3nz5qXutQMPPDAReKx8uphZJ1vr7eZ3LD53W06eko+gwoFuCc+VV6NvrBbrfswxxyTxrh0BcdAQlj1QG9xxJbI636jdOtwo8XWIqeEFofPBG9C1xgLbgFF+3qIhru6FZVddl6f/ChOkCWUQSkaQPVAbLDF3mKVEJq434o5FTtacpdWbrkVVGKB5xeLhPM5hEaHw66zLAl43cE0WFNdIA1Cvr3mmZATZA7WB7AiDkFxte6kh61jwPHl0VXaGXsyfPz+JdVJjWaU3q55glwW6XkSXdAJ1867VAqUvvmQE2QO1gYSITtlWoprbVsdCdv9Zbl1q4n0ltoQzIQHyO0y0YYF7YdmBKq92Htl5FRaWkiWqIHugNpBQgYp4neuN8HXA+tseyzTYvffeO/W+i6MNuHDYhOLII49M0248RhqvF3E7nQDRnc91W3B6tZD0I4LsgdpgFfPIZwSu48ID4mbiKb5RVut76Tvn8jsqPE/B7zPRc6wvBPC4dq2yDAH13WtboEovmw2yByYURDPENsbKwsHNlsIb2nuOnFx6rr04n1Xu1AVvXTxKRpA9UBtIkzdqQEZHu0A4lpZX4Dx+Ho6E8u9q5Yl2uXutXRfc41lzXgELX3KOHYLsgdpATm42o
iqQqZu3RmbEc0jdUeWp48jne6IfK5+JCVR/6rn8uFFXnSj0Xs95LUqum/fQCy2gXxFkD9QGi5zHPRG8EHIsAvq7hQGhPZ5LjnwGRDrk3A3AyJNiNdYgqdJa9e1KaDu1yF6HhyAEcN2KeMaq9BtkBNkDtZHJrn6dWy2dlS3xSEA0M+ZPPvnkVGZrMKV4Xc+5cVIWAzl2o6T0nufFA8F5EL4iaCcW2SJj8XANFg3aQMkIsgdqA9kJakpQEUnXWx1XnppObHN4rk0YHVR5c+NYclqAAph2e+JHgkXDgiQMcD4bU1ikSkaQPVAbLKz8ukEVLC6yi6tHc+VZZ5bcaGmHybM2e7BoqIU3RoqV32ijjVL+vU75bR2I8ekDCnlU+llknLtkBNkDtYGErLsJshpglLsa3cwyjwQWWzfbXnvtVe2xxx7pe9bbebjVJtHaXMKmjb3cr43nof7egqRox/XWrQsYVATZA22DC846K3ixQQRXfjTCLwkQAwl+rPsmm2ySFpZeLST9iiB7oG1wwddYY43kHpsww4J2knOvAyHCaGHCSJDSk6fX9SY84E2UjiB74P8Auajs4l7WW64659UdiJ03WeTKa2jpdZNJTps5vDZtwHXUeQ1xugwAJZ4LT5wbbWxWKQiyBxYDMiEVorCOdnrRqYbQNl8wj11Vm6IYgpdiFZNh5ct7MQ46A8Edzo+syE5dtwCNBKGEa8jddKy5XL5932gEpSPGUgUWA8IogNGkkgdActNZ2GzZWXuWXQOL28dOqlT1KVOmpHieKNct5PClzRTvZMiVE9osAEPhOlyb6z7xxBPT80zA2W+//SJefwaxsWPg/0BaTZUcy81S3nfffWnAhCkzCKhIRbrMLq3ScKy6x3ieVlYxfbfkMp4KSZ1P2szhvBaS4c7N7ee+2y6ajrDxxhtXu+22W9oOquSquVbEuxBYDIgkRSUNRtxi6VnJ7EI75Mpt9aQoRirtoIMOSsTkCXDxWeVuHUYEdR1i7XxwxYcjumt84oknqgULFqRyXDl7G1forgv3fRHCsgcWAxddvC6ldt1116XyVjE84iE6V3rPPfdMhKLGI6GhkR7DsiKdx6lWG42gvQIvg+tOS6ApuKaZM2emSbI8gcAiBNkDz4IF55KbD6de3QYO3Gaz4dXEi9tNgeW+c49z3TqXnxfAlVZkY1MGaj6yc8dHcr27AWvuevKitHDhwnQd73vf+9JClPdsDyxCkD2QiEP5tsuLLZ0cctTq1mfNmlVNnjw5ueWEO1ZdN9rQslYjqlh4vxMz22PNOT2OGo74vSI8z0GxjEyBaxU+CCNcG1HOwmQhCiyOUOMDKUfOmh933HHJsutGU9qq8gxZuecIpa98//33f7bcdSiQkLW1aBx11FGJ9Nz+adOmVTvttFMifbfgMWiHNaDSBhCw8847J4uueMaiEhZ9eATZC4WPnXXkBtt/zTQY1pA1J775ykKyxshOhUciLv1oold2rxFdKKC4xc9SZmrq7Q1HYRfTS6GN5OK7PgcNQUbAtcrtO69rVsxj8XCtjuxtBEZGkL1A+MgRx3w3cbniGURhyRWhyJ9T4zOQDelYzTqEcn4WWOwuZWcbJmk7CwGvAPEJaVxvhHVOpLfY5OcKAWQAkJxgaLHJE23oADQDC9Laa6+d0nKl79BaB0H2guCjRhadYNJUFGwEssHi9ttvn9x3RES8VuRbpN2YG7mRVR97Jr7mFJaeyw9eC9lZ+RxnE/oQPRfweF1eBU9DXt91Kt5RrkuUC7e9HoLshSATXdx9ySWXVGeeeWYiCpLvvvvuqSJuvEjjtRE/LzRIr6SVwMbiI7S/t96KroXVN55KfTuFXV6fN2CB8PdeCX6lIMheCFhK8fNFF11UXXzxxcmay0erhDNcYiJ6vd1qXHTkdiC/PDnrL+3nZxAuiOmFEjmu9zvXGCTvHEH2AqC3m8tOhFOAwlqqMCOYyY8PddsnCm49Fp3bzq23EAABELEzuQO9QZB9g
IFA0mqXX355de2116b4XHoK0anqKsyinLQcBNkHFNxkNepi4zlz5iRF3Qy4XXbZpVpzzTWT1Qx3uCwE2QcQ3GGpKoUyJ510UrLu06dPT4UtGleC5GUiyD5g8HEiutSaPnM5atVwW265ZVLcJ0KICzQTQfYBgo9StZttkMXo1G3WXIWZaS9ReFI2QuocEBDjjG5SM051R3Rjmk2PycUngbIRZB8AiNGl1/RzIztxTukrq66stJcdZ4H+RZC9z8F1V5BiLNSpp56arDtrrgtMjB4IZATZ+xwaRm6//fakumtomTp1anLfxeiBQCuC7H0KFp27Lo+uBNawCUTfaqutUoweqntgKILsfQhEV2aqk4zqbhSUHVoMcVDzrp48YvTAUATZ+xAEObE5ohvmwGU3Nllvt3bRIHpgOATZ+wysun5wgyeuvPLKZMW33XbbVDQTJbCB0RBk7zMY6EB5P+KII1J7qGmvBkIuqc61QP8gyN4nYNER3eAHE1V1sFHd7WtuukxY9MBYCLL3CQhyprrY8YRlVwJLeZdLD6seqIMge59Ai6pxzoY3GtekucWYJvPaA4E6CLL3Aajv5rnbeEG1HEFOrG5sU7jvgboIsjccGlzE52eccUYaFmnwhLlx472HWmDwEGRvMFh0e6PnLY7MjrMVsTnpgUC7CLI3FNR3E2aIcabBKpahvKuUi770QCcIsjcU3HcbKmhbNRHW3PQNNtggbZ0U7nugEwTZGwpW3SYKd911V6p333HHHdO2TDFaOdAp4s5pILjwat7F6Sz8vvvum4ZQRCdboBsE2RsGRLdF0nXXXZfceM0t6623XtoQMdz3QDcIsjcIquSUxCK6rZoUzBgvZdeWvOlhINApguw9AotsmETrNsM2abCPuO/lyvN+ZlJqQ+H5ps7oTb/iiisS6eXUbdEU3WyBXiBGSXcBllhMjeTIieT6zH31s997e1ll011VvNmF1KaF3HI17YhMdHMuY6DtsDp79uykvttdVbotRLlALxBk7xDIyVI/+uijaQbcvffeWz3++ONJRdd6yoJbCABZkVrvOZJzy+0zbjsmeXP7jrPqd999d3X00Uenc3/4wx+uttlmm6h9D/QMQfY24K1CYFVtlPIbbrihuv/++9MYZ2C1qeZy4ZpVWHPut11KNbJ4ngWBAMfys/TmxbHinmNg5Omnn17tv//+aZ7ciiuuGFY90DME2WuC5UVSFW06zwx4FJNzxRHWaKjll18+lbLqL89uOrJbIFh7cTvCP/HEE6nO3cHtz4/1GFNoPv/5zycV3mIQCPQKQfYxwKVmuVld7rrhEQQ3cbdNEm19vOqqq1bLLbdc7flvXHwkt2DIp1tAHnnkkRTrL7300tWee+5ZbbjhhmnhiF71QK8QZB8F3hoEVLI6b968tCsq13qHHXZIDSlcdjF1p6628/MYKPYLFy6srrrqqjScgodggKThFL6PtFugFwiyjwAWXRrNJonnnntu6iffdNNNUzUbcQ3JtZl2G1N7+6XiWPvHHnss7b5qZxfYfvvtq7333jt1u0XsHugWQfZhgHwENd1mct6+F0PPmDEj1adz13tNPh8DwovpjYi++uqrk9ov9SYFRxOwZ1sg0Cme9z9P45nvA08D6cTT119/fdppBbjsXHfxOcKNh5UV63PXiXKUeUp+juuRXshAyONNBAKdIHzDIaCG6zY788wzUyzNsnKnzXubiOkwFhJkN41GzI78hlfcdNNN1VNPPZXCi0CgEwTZWyBFpnec627em/5xhS1i9ImMmS0oyyyzTNII5NstQGeffXYS8eTnA4FOEGR/Btx38TI13GaJiL7LLrssscIWhJevl4IzSda13XjjjSlVN1xtfSAwFoLsz4BVV/Iq5811PvDAA9PQCLFzXXCxxdm61qj3imS6gbBBaa3BFZMmTUpFOIZZ5Iq9QKAdBNmfBqtOcVcwg0jrrLNO2lJptBx6LoFVC58P7rbnq5A755xzUl5e1VynYN0tN
iuttFIaH03AU55rs4hAoF2EGv80WGQVcpdddlkiFEGO8j4S0S0OOtSUzSIfK+5QBeeruP+CCy5IXW8s87LLLttxYUyrSk+Zt5BI/bH04y0WBgYLkWd/GqrY5syZk3ZFXX311asDDjggpbpGgpiZq26WO4K3QjjA4lPOLRhy5LvuumvqbOsGhDnFParsEP8rX/lK6qJbEnpCoD9RPNlZdS64yjhEnT59eiIn6zkSvGWaYFS8IXaGc6mblyYTX9t4UfqMG96pZc9w7ltvvbU677zzkudw1FFHVausskoifCBQB8WbBVVriCkO1tzCoo8lynGfPRbZuNP5MFVGj7p02cc+9rFqu+22S5VvvSiE8ZqabTTHsPLq9ZXzBgJ1UTzZWeYHH3wwKec5vq5DTsMobNagoi4PpuBeez7iG0ohV+5vvUBeYBTc8BLuuOOOJAhGFBaoi6LJjiiZ7IDsBlCMBc9DNFNqiHRaX82Os2D4G2GOcCeVR1ATHvQCQovcK0+sE3706tyBwUfxlp1LjKisOSLVGQMlfjZxhkhnssxxxx2XlHxKPGHummuuqU455ZQkqFkMhAq9AIu+1FJLJbFPPl+ar1fnDgw+irfsudMMkRC9zj5qFHDuOlfd4AoWXv16rmHXpsqF1yHXS1fe67LuRl7xSDTIyPcHAnVQvGVHFi65mJswVydeFz+Lzwl0FHfCHCsrLSaWlqf/wAc+kDrlTLPpZdxOI0B4i5QjSmcDdVE82bnkYmyEbEc193gWdoUVVqi22GKLFEdLu4G6eiQ3dIKoxiIjJRdfLJ9n2Im52yWra/Tarlu8HgJdoC6KJztriYzI0y5xPJf7v/LKKydLj4gWACJf68LBe5CX17UmZSam11lndDRX3GvXhWv0+HzdgUBdFH+3ICUXXhUdUrZDeI9lXVlo33s+d965MoH9HqER2/hp+gBBULWetlVqfTtxt9fjvmcL37qoBAKjoWiys46mt3K1EUjqrA7xEDiTm8U2jFK8T4xT3Wbohb8hpq+aYRB9p512qg466KB0TJs2LdXjc/2FEXXgNWUPLB4WKEeQPVAXxVt2JDXBVezMKreWvw4HFtvuqlxy/eXm1Hm+qjlz6hDb76Tl9MVL0VkE9MZrmeXuey2Li5+l0eoKeBaP3FknXHDEqOlAXRRPdsq2FBkSs8hGR48GRGWlueDnn39+yq2vu+66aVcXjS+sLRddeyv3HhmJd1J0LLN+eTG7dlpFPF6/bt28BUJ+3TUqw5Vzr7tQBALFu/GIaL6c77WtZkV9NHDNeQAWCCRWHkuVl3fffPPN00YPLDaLr1CHkObwHF1ytmN+4IEHnl1c6rrx3Hcxv3BDdx7LHm58oC6K73pjqRFcThw5jYDiko9mbQlwiMet5qIjssXCz4iIwKx2q4ttYWCVxfQKb1h3Qt1nP/vZlKvX4DIafEzKb3kUvIZPfvKTaUYdwgcCdVC8G4+oXGmTaVhY1XDc79GgsEV6DdFZVkQHCwRPQXfaUPfa63iOuN7GjYcffnh6nBic6DbWmmshoRWw7LwFhTxChkCgLoonO6Ii72abbZYsMTcbqUaD5yDvcC50/n1eAFrhb15LrM2S62BDWL8fCxYFDTtCAS68UKFurB8IQPFkB+QkrnHjVbkR4HoFFhtRCXJ65pE1hwGsP8LzBoZbHDKcQ8WdQ/aAIGjRGO05gcBQBNmfBtJobGExQZUbYlK/u5U0xOrScWJ0DTIWEptQGJgh9WZU9Widdtx3GoC94Al6hmsQFAOBdhFkfxrIzmKKg+21rthFWg252q1dHw4suUo5e7gpwJk/f35aAPbee+80ssprDwcLBQ+AVyCdpyNPzE/8CwTaRQycbAF324aKZ511VlLNP/OZz6QmF1a/G5eZ8Mc6E/7E2eJ0oiCSjxTf+1i4/BR4wyXl82fNmlXtvPPOsatroCPEHdMCBFxvvfVSiyqLrj+d242o3YDwR0HnNVDg87QZxB9pEZHLf+ihh6pLL
700WfYpU6YkEVGMH0QPdIKYG98CJEL4HEPbaslQR8QcmjdvF0jNijtGIyuLzhNAdF6GkVfieiOpV1tttXR93XgZgXIRZB8CZER2BFP0wp1Xi8719nuER9bxIBxvguuO6AZh8Cq8rkIfU2vV1QfRA50iyD4McnEMa66E1v5qrLzfyZGPFmt3Cqq7RcVwC1tHqbJTHTdz5sy0k2xY9EC3CIFuBHhbqOiKbMTuOtlYXmW1CNjLcVNey4YT119/fRIHFfXYn33GjBlJQ0D6iNMD3SLIPgq8NUpZtamaLXfhhRdWTz75ZOo40+WmNp27LyXGG6hLSOeVVrOYSO85N0vOe/A3AiExzgScIHqgVwiy1wBiapaxdztSSoMhtwIX7bE63qTnDMHg6iO/2J6rn5/PTbdwEPyk+JxPYQ1NQAGPtJxOufXXXz9tGeXcdSbdBgJ1EWSvCYTNKrnBFYZGIqu3D8ml09SrO8T1hLWcWuP+s+JIrvONNTfhRkoP+TXUUNptFa2whxDouYFALxFkbxPeLm2sCMsqE++0nIq5LQbEO7F8TrEhu+cgvNy5ElwLhwUi7wOvJ15oEC57YDwRZO8AmbxccwRHYOTn3rP8Unbc9LwdVO50U1DD7ReLK7CxMHD3WxeHQGC8EGTvERA+x+O++tmC4O1FZIQWg3PR5ct9H6m0wEQiyB4IFILwGwOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoBAE2QOBQhBkDwQKQZA9ECgEQfZAoAhU1f8DOy5IPdip4X4AAAAASUVORK5CYII="
    }
   },
   "cell_type": "markdown",
   "id": "b38f0f39",
   "metadata": {},
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "a394bfd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy graph for the walkthrough: 4 nodes and 3 edges, stored one edge per column.\n",
    "# edge_index therefore has shape [2, E] with E = 3.\n",
    "edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], dtype=torch.long)\n",
    "# Node feature matrix of shape [N, in_channels] = [4, 1]; every feature equals 1.\n",
    "x = torch.ones((4, 1), dtype=torch.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2428f500",
   "metadata": {},
   "source": [
    "x has shape [N, in_channels], N表示节点个数    \n",
    "edge_index has shape [2, E], E表示边的个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "0c352ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree\n",
    "\n",
    "class GCNConv(MessagePassing):\n",
    "    \"\"\"Hand-written GCN layer that can print every intermediate tensor.\n",
    "\n",
    "    The forward pass adds self-loops, applies a linear transform, builds the\n",
    "    per-edge weight ``deg_inv_sqrt[row] * deg_inv_sqrt[col]`` and sums the\n",
    "    weighted messages (``aggr='add'``).\n",
    "\n",
    "    Args:\n",
    "        in_channels: Size of each input node feature vector.\n",
    "        out_channels: Size of each output node feature vector.\n",
    "        verbose: If True (default), print the intermediate tensors on every\n",
    "            forward pass. Pass False to silence the layer, e.g. when it is\n",
    "            called repeatedly inside a training loop.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, verbose=True):\n",
    "        super().__init__(aggr='add')  # \"Add\" aggregation (Step 5).\n",
    "        self.lin = torch.nn.Linear(in_channels, out_channels)\n",
    "        self.verbose = verbose\n",
    "\n",
    "    def _debug_print(self, title, value):\n",
    "        \"\"\"Print ``title`` and then ``value`` when verbose output is enabled.\"\"\"\n",
    "        if self.verbose:\n",
    "            print(title)\n",
    "            print(value)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Run one GCN propagation step.\n",
    "\n",
    "        Args:\n",
    "            x: Node features with shape [N, in_channels].\n",
    "            edge_index: Edge list with shape [2, E].\n",
    "\n",
    "        Returns:\n",
    "            The result of ``self.propagate`` on the transformed features.\n",
    "        \"\"\"\n",
    "        # Step 1: Add self-loops to the adjacency matrix.\n",
    "        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
    "        self._debug_print(\"\\nAdd self-loops edge_index--------------\", edge_index)\n",
    "\n",
    "        # Step 2: Linearly transform node feature matrix.\n",
    "        x = self.lin(x)\n",
    "        self._debug_print(\"\\nLinearly transform--------------\", x)\n",
    "\n",
    "        # Step 3: Compute normalization.\n",
    "        row, col = edge_index\n",
    "        if self.verbose:\n",
    "            print(\"\\nrow-------\", row)\n",
    "            print(\"\\ncol-------\", col)\n",
    "        deg = degree(col, x.size(0), dtype=x.dtype)\n",
    "        self._debug_print(\"\\ndegree------------------\", deg)\n",
    "        deg_inv_sqrt = deg.pow(-0.5)\n",
    "        # A zero degree gives 0 ** -0.5 == inf; zero it so the product below stays finite.\n",
    "        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n",
    "        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n",
    "        self._debug_print(\"\\nnorm--------------\", norm)\n",
    "\n",
    "        # Step 4-5: Start propagating messages.\n",
    "        return self.propagate(edge_index, x=x, norm=norm)\n",
    "\n",
    "    def message(self, x_j, norm):\n",
    "        \"\"\"Scale each per-edge feature row by its normalization weight.\n",
    "\n",
    "        Args:\n",
    "            x_j: Per-edge features with shape [E, out_channels].\n",
    "            norm: One weight per edge, as computed in ``forward``.\n",
    "\n",
    "        Returns:\n",
    "            ``norm.view(-1, 1) * x_j``.\n",
    "        \"\"\"\n",
    "        # Step 4: Normalize node features (computed once, then printed and returned).\n",
    "        weighted = norm.view(-1, 1) * x_j\n",
    "        self._debug_print(\"\\nx_j--------------\", x_j)\n",
    "        self._debug_print(\"\\nnorm.view(-1, 1) * x_j--------------\", weighted)\n",
    "        return weighted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "41dfc455",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Add self-loops edge_index--------------\n",
      "tensor([[1, 2, 3, 0, 1, 2, 3],\n",
      "        [0, 0, 0, 0, 1, 2, 3]])\n",
      "\n",
      "Linearly transform--------------\n",
      "tensor([[ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327]], grad_fn=<AddmmBackward>)\n",
      "\n",
      "row------- tensor([1, 2, 3, 0, 1, 2, 3])\n",
      "\n",
      "col------- tensor([0, 0, 0, 0, 1, 2, 3])\n",
      "\n",
      "degree------------------\n",
      "tensor([4., 1., 1., 1.])\n",
      "\n",
      "norm--------------\n",
      "tensor([0.5000, 0.5000, 0.5000, 0.2500, 1.0000, 1.0000, 1.0000])\n",
      "\n",
      "x_j--------------\n",
      "tensor([[ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327]], grad_fn=<IndexSelectBackward>)\n",
      "\n",
      "norm.view(-1, 1) * x_j--------------\n",
      "tensor([[ 0.4005, -0.0664],\n",
      "        [ 0.4005, -0.0664],\n",
      "        [ 0.4005, -0.0664],\n",
      "        [ 0.2003, -0.0332],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327],\n",
      "        [ 0.8011, -0.1327]], grad_fn=<MulBackward0>)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.4019, -0.2322],\n",
       "        [ 0.8011, -0.1327],\n",
       "        [ 0.8011, -0.1327],\n",
       "        [ 0.8011, -0.1327]], grad_fn=<ScatterAddBackward>)"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the layer with 1 input channel and 2 output channels, then run one forward pass.\n",
    "conv = GCNConv(in_channels=1, out_channels=2)\n",
    "ret = conv(x, edge_index)\n",
    "ret  # bare last expression: shown as the cell output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90a29305",
   "metadata": {},
   "source": [
    "# 2. GCN的简单实例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "819069a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "# Import under an alias: a plain `import GCNConv` would rebind the name and shadow the\n",
    "# hand-written GCNConv class defined in section 1 of this notebook.\n",
    "from torch_geometric.nn import GCNConv as PyGGCNConv\n",
    "\n",
    "class GNN(torch.nn.Module):\n",
    "    \"\"\"Two-layer GCN that returns per-node log-probabilities.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Input feature size. When None, falls back to\n",
    "            ``dataset.num_node_features`` from the notebook-level ``dataset``.\n",
    "        hidden_channels: Output size of the first convolution (default 16).\n",
    "        out_channels: Output size of the second convolution. When None,\n",
    "            falls back to ``dataset.num_classes``.\n",
    "        dropout: Dropout probability applied after the ReLU (default 0.5,\n",
    "            which is also the default of ``F.dropout``).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels=None, hidden_channels=16, out_channels=None, dropout=0.5):\n",
    "        super().__init__()\n",
    "        # The global `dataset` is only read for sizes the caller did not supply,\n",
    "        # so `GNN()` keeps working exactly as before.\n",
    "        if in_channels is None:\n",
    "            in_channels = dataset.num_node_features\n",
    "        if out_channels is None:\n",
    "            out_channels = dataset.num_classes\n",
    "        self.dropout = dropout\n",
    "        self.conv1 = PyGGCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = PyGGCNConv(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Apply conv1 -> ReLU -> dropout -> conv2 -> log_softmax over dim 1.\n",
    "\n",
    "        Args:\n",
    "            data: Object exposing ``x`` and ``edge_index`` attributes.\n",
    "        \"\"\"\n",
    "        x, edge_index = data.x, data.edge_index\n",
    "\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        # Dropout is active only while the module is in training mode.\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "9fe972c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "# Load the Cora dataset through Planetoid, using ./data/Cora as the dataset root directory.\n",
    "dataset = Planetoid(name='Cora', root='./data/Cora')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "c06386b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed PyTorch's RNG so weight initialisation and dropout masks repeat from run to run\n",
    "# (some CUDA kernels may still introduce small nondeterminism).\n",
    "torch.manual_seed(0)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = GNN().to(device)\n",
    "data = dataset[0].to(device)  # first (index 0) graph of the dataset\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
    "\n",
    "model.train()  # training mode: dropout in GNN.forward is active\n",
    "for epoch in range(200):\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data)\n",
    "    # Only the nodes selected by train_mask contribute to the loss.\n",
    "    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "d9806c9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.7980\n"
     ]
    }
   ],
   "source": [
    "model.eval()  # evaluation mode: dropout in GNN.forward is disabled\n",
    "# Inference only: skip autograd bookkeeping for this forward pass.\n",
    "with torch.no_grad():\n",
    "    pred = model(data).argmax(dim=1)\n",
    "# Accuracy is measured on the nodes selected by test_mask.\n",
    "correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()\n",
    "acc = int(correct) / int(data.test_mask.sum())\n",
    "print(f'Accuracy: {acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7396a9d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
